DaCe Orchestration for the Diffusion Granule #514
I only looked at the decorator.py module and tried to provide as much information and as many suggestions as possible, to make your life easier and help get this done and merged ASAP.
model/atmosphere/diffusion/src/icon4py/model/atmosphere/diffusion/diffusion.py
def orchestrate(func: Callable | None = None, *, method: bool | None = None):
    def _decorator(fuse_func: Callable):
        compiled_sdfgs = {}  # Caching

        def wrapper(*args, **kwargs):
            if settings.dace_orchestration is not None:
                if "dace" not in settings.backend.name.lower():
                    raise ValueError(
                        "DaCe Orchestration works only with DaCe backends. Change the backend to a DaCe supported one."
                    )

                if method:
                    # self is used to retrieve the _exchange object -on the fly halo exchanges- and the grid object -offset providers-
                    self = args[0]
                    self_name = next(iter(inspect.signature(fuse_func).parameters))
                else:
                    raise ValueError(
                        "The orchestration decorator is only for methods -at least for now-."
                    )

                fuse_func_orig_annotations = copy.deepcopy(fuse_func.__annotations__)
                fuse_func.__annotations__ = to_dace_annotations(
                    fuse_func
                )  # every arg/kwarg is annotated with DaCe data types

                exchange_obj = None
                grid = None
                for attr_name, attr_value in self.__dict__.items():
                    if isinstance(attr_value, decomposition.ExchangeRuntime):
                        exchange_obj = getattr(self, attr_name)
                    if isinstance(attr_value, icon_grid.IconGrid):
                        grid = getattr(self, attr_name)

                if not grid:
                    raise ValueError("No grid object found.")

                order_kwargs_by_annotations(fuse_func, kwargs)

                compile_time_args_kwargs = {}
                all_args_kwargs = [*args, *kwargs.values()]
                for i, (k, v) in enumerate(fuse_func.__annotations__.items()):
                    if v is dace.compiletime:
                        compile_time_args_kwargs[k] = all_args_kwargs[i]

                unique_id = make_uid(fuse_func, compile_time_args_kwargs, exchange_obj)

                default_build_folder = Path(".dacecache") / f"uid_{unique_id}"

                parse_compile_cache_sdfg(
                    unique_id,
                    compiled_sdfgs,
                    default_build_folder,
                    exchange_obj,
                    fuse_func,
                    compile_time_args_kwargs,
                    self_name,
                    simplify_fused_sdfg=True,
                )
                dace_program = compiled_sdfgs[unique_id]["dace_program"]
                sdfg = compiled_sdfgs[unique_id]["sdfg"]
                compiled_sdfg = compiled_sdfgs[unique_id]["compiled_sdfg"]

                # update the args/kwargs with runtime related values, such as
                # concretized symbols, runtime connectivity tables, GHEX C++ pointers, and DaCe structures pointers
                updated_args, updated_kwargs = mod_xargs_for_dace_structures(
                    fuse_func, fuse_func_orig_annotations, args, kwargs
                )
                updated_kwargs = {
                    **updated_kwargs,
                    **dace_specific_kwargs(exchange_obj, grid.offset_providers),
                }
                updated_kwargs = {
                    **updated_kwargs,
                    **dace_symbols_concretization(
                        grid, fuse_func, fuse_func_orig_annotations, args, kwargs
                    ),
                }

                sdfg_args = dace_program._create_sdfg_args(sdfg, updated_args, updated_kwargs)
                if method:
                    del sdfg_args[self_name]

                fuse_func.__annotations__ = (
                    fuse_func_orig_annotations  # restore the original annotations
                )

                with dace.config.temporary_config():
                    dace.config.Config.set(
                        "compiler", "allow_view_arguments", value=True
                    )  # Allow numpy views as arguments: If true, allows users to call DaCe programs with NumPy views (for example, “A[:,1]” or “w.T”)
                    return compiled_sdfg(**sdfg_args)
            else:
                return fuse_func(*args, **kwargs)

        return wrapper

    return _decorator(func) if func else _decorator
I would push as much work as possible from the run-time function wrapper to the decorator, since the decorator is only executed once, at decoration time. I think there were still some errors in the logic that sorts the kwargs when some of them are missing at runtime. I would also move the cache lookup out of the parse_compile function, since I don't think it belongs there and it forces an extra function call at runtime for no reason.
Here is a possible rewrite of the function addressing some of the issues. I think there is still room for further improvements and cleanups at the bottom of the wrapper, where the DaCe compilation and some other processing seem to happen, but I don't know enough about DaCe to judge whether all the operations there need to be repeated at every function call.
def orchestrate(func: Callable | None = None, *, method: bool | None = None):
    def _decorator(fuse_func: Callable):
        if settings.dace_orchestration is not None:
            if "dace" not in settings.backend.name.lower():
                raise ValueError(
                    "DaCe Orchestration works only with DaCe backends. Change the backend to a DaCe supported one."
                )

            self_name = next(iter(inspect.signature(fuse_func).parameters))
            # Use a separate local name: assigning to `method` here would make it local
            # to `_decorator` and raise UnboundLocalError when it is read.
            # If `method` is None, assume a method when the first argument is called 'self'.
            is_method = (self_name == "self") if method is None else method
            if not is_method:
                raise ValueError(
                    "The orchestration decorator is only for methods -at least for now-."
                )

            local_cache = {}  # Caching compiled func versions

            def wrapper(*args, **kwargs):
                # self is used to retrieve the _exchange object -on the fly halo exchanges- and the grid object -offset providers-
                self = args[0]

                exchange_obj = None
                grid = None
                for attr_name, attr_value in self.__dict__.items():
                    if isinstance(attr_value, decomposition.ExchangeRuntime):
                        exchange_obj = getattr(self, attr_name)
                    elif isinstance(attr_value, icon_grid.IconGrid):
                        grid = getattr(self, attr_name)

                # Use assert here to allow disabling the check when running in production
                assert grid is not None, "No grid object found in the call arguments."

                # Add DaCe data types annotations for all args and kwargs
                dace_annotations = to_dace_annotations(fuse_func)

                # To extract the actual values from the function parameters defined as compile-time,
                # we first need to sort the run-time arguments according to their definition
                # order, adding `None` for the missing ones to make sure we don't use
                # the wrong one by mistake.
                # Assumption: `dace_annotations` has one entry per parameter, `self` included,
                # in definition order (the original indexing relies on the same layout).
                param_names = list(dace_annotations)
                all_args = [*args, *(kwargs.get(key) for key in param_names[len(args):])]
                compile_time_args_kwargs = {
                    k: arg
                    for arg, (k, v) in zip(all_args, dace_annotations.items(), strict=True)
                    if v is dace.compiletime
                }

                unique_id = make_uid(fuse_func, compile_time_args_kwargs, exchange_obj)

                # Capture the original annotations before the cache lookup so that
                # they are also defined on a cache hit.
                fuse_func_orig_annotations = fuse_func.__annotations__
                if (cache_item := local_cache.get(unique_id, None)) is None:
                    fuse_func.__annotations__ = dace_annotations
                    default_build_folder = Path(".dacecache") / f"uid_{unique_id}"
                    cache_item = local_cache[unique_id] = parse_compile_cache_sdfg(
                        default_build_folder,
                        exchange_obj,
                        fuse_func,
                        compile_time_args_kwargs,
                        self_name,
                        simplify_fused_sdfg=True,
                    )

                dace_program = cache_item["dace_program"]
                sdfg = cache_item["sdfg"]
                compiled_sdfg = cache_item["compiled_sdfg"]

                # update the args/kwargs with runtime related values, such as
                # concretized symbols, runtime connectivity tables, GHEX C++ pointers, and DaCe structures pointers
                updated_args, updated_kwargs = mod_xargs_for_dace_structures(
                    fuse_func, fuse_func_orig_annotations, args, kwargs
                )
                updated_kwargs = {
                    **updated_kwargs,
                    **dace_specific_kwargs(exchange_obj, grid.offset_providers),
                }
                updated_kwargs = {
                    **updated_kwargs,
                    **dace_symbols_concretization(
                        grid, fuse_func, fuse_func_orig_annotations, args, kwargs
                    ),
                }

                sdfg_args = dace_program._create_sdfg_args(sdfg, updated_args, updated_kwargs)
                if is_method:
                    del sdfg_args[self_name]

                fuse_func.__annotations__ = fuse_func_orig_annotations

                with dace.config.temporary_config():
                    dace.config.Config.set(
                        "compiler", "allow_view_arguments", value=True
                    )  # Allow numpy views as arguments: If true, allows users to call DaCe programs with NumPy views (for example, “A[:,1]” or “w.T”)
                    return compiled_sdfg(**sdfg_args)

            return wrapper
        else:
            return fuse_func

    return _decorator(func) if func else _decorator
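As a complement to the argument-ordering step in the rewrite, the same idea can be sketched with the standard library's `inspect.signature(...).bind_partial`, which maps positional and keyword arguments to parameter names in definition order, however the caller passed them. This is only an illustration: `collect_compile_time_args`, the `COMPILETIME` marker (a stand-in for `dace.compiletime`) and the `Granule` class are hypothetical names, not part of the PR.

```python
import inspect
from typing import Any, Callable

COMPILETIME = object()  # hypothetical stand-in for dace.compiletime


def collect_compile_time_args(
    func: Callable, annotations: dict, args: tuple, kwargs: dict
) -> dict:
    """Return {parameter name: value} for parameters annotated as compile-time."""
    # bind_partial tolerates missing arguments; apply_defaults fills declared defaults.
    bound = inspect.signature(func).bind_partial(*args, **kwargs)
    bound.apply_defaults()
    return {
        name: bound.arguments.get(name)  # None when the caller omitted it
        for name, annotation in annotations.items()
        if annotation is COMPILETIME
    }


class Granule:
    def run(self, field, dtime, limit=3):
        return field


annotations = {"self": COMPILETIME, "field": float, "dtime": float, "limit": COMPILETIME}
g = Granule()
# Keyword order at the call site does not matter, and `dtime` may be missing.
explicit = collect_compile_time_args(Granule.run, annotations, (g,), {"limit": 7, "field": 1.0})
defaulted = collect_compile_time_args(Granule.run, annotations, (g, 1.0), {})
print(explicit["limit"], defaulted["limit"])
```

Binding by name avoids relying on the positional layout of `[*args, *kwargs.values()]`, at the cost of one `inspect.signature` call, which could itself be hoisted to decoration time.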
# Add DaCe data types annotations for all args and kwargs
dace_annotations = to_dace_annotations(fuse_func)
This was my mistake, but this should be moved out of the wrapper to the decorator, since it only needs to happen once at decoration time.
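The split between decoration-time and call-time work can be sketched with a toy decorator. This is a minimal illustration under stated assumptions, not the PR's implementation: `orchestrate_sketch`, the `calls` counters, the cache key and the "compile" stand-in are all hypothetical.

```python
import functools
import inspect

calls = {"decoration": 0, "runtime": 0}  # counters used only for illustration


def orchestrate_sketch(func=None, *, method=None):
    def _decorator(fuse_func):
        # Decoration time: everything here runs once per decorated function.
        calls["decoration"] += 1
        first_param = next(iter(inspect.signature(fuse_func).parameters), None)
        # A separate local name avoids rebinding the enclosing `method` argument.
        is_method = method if method is not None else first_param == "self"
        if not is_method:
            raise ValueError("only methods are supported in this sketch")
        cache = {}

        @functools.wraps(fuse_func)
        def wrapper(*args, **kwargs):
            # Call time: only the cache lookup and the call itself remain.
            calls["runtime"] += 1
            key = tuple(sorted(kwargs))  # hypothetical cache key
            if key not in cache:
                cache[key] = fuse_func  # stand-in for an expensive compile step
            return cache[key](*args, **kwargs)

        return wrapper

    return _decorator(func) if func else _decorator


class Diffusion:
    @orchestrate_sketch
    def run(self, x, factor=2):
        return x * factor


d = Diffusion()
outputs = [d.run(1), d.run(2), d.run(3, factor=3)]
print(outputs, calls)
```

Anything that depends only on the decorated function (its signature, its annotations) belongs in `_decorator`; anything that depends on the actual arguments has to stay in `wrapper`.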
def generate_orchestration_uid(
    obj: Any, obj_name: str = "", members_to_disregard: tuple[str] = ()
) -> str:
    """Generate a unique id for a runtime object.

    The unique id is generated by creating a dictionary that describes the runtime state of the object.
    For primitive types, the dictionary contains the type and the value.
    For arrays, the dictionary contains the shape and the dtype -not their content-.

    Keep in mind that this function is not supposed to be generic, but it is used only for the DaCe orchestration purposes.
    """
    primitive_dtypes = (*orchestration_dtypes.ICON4PY_PRIMITIVE_DTYPES, str, uuid.UUID, np.dtype)

    unique_dict = {}

    def _populate_entry(key: str, value: Any, parent_key: str = ""):
        full_key = f"{parent_key}.{key}" if parent_key else key

        if full_key in members_to_disregard:
            return

        if isinstance(value, primitive_dtypes):
            unique_dict[full_key] = {"type": "primitive_dtypes", "value": str(value)}
        elif isinstance(value, (np.ndarray, gtx.Field)):
            unique_dict[full_key] = {
                "type": "array/field",
                "shape": str(value.shape),
                "dtype": str(value.dtype),
            }
        elif isinstance(value, (list, tuple)):
            if all(isinstance(i, primitive_dtypes) for i in value):
                unique_dict[full_key] = {
                    "type": f"array-like[{'empty' if len(value) == 0 else type(value[0])}]",
                    "length": str(len(value)),
                }
            else:
                for i, v in enumerate(value):
                    _populate_entry(str(i), v, full_key)
        elif value is None:
            unique_dict[full_key] = {"type": "None", "value": "None"}
        elif hasattr(value, "__dict__") or isinstance(value, dict):
            _populate_unique_dict(value, full_key)
        else:
            raise ValueError(f"Type {type(value)} is not supported.")

    def _populate_unique_dict(obj: Any, parent_key: str = ""):
        if (hasattr(obj, "__dict__") or isinstance(obj, dict)) and not isinstance(
            obj, decomposition.ExchangeRuntime
        ):
            obj_to_traverse = obj.__dict__ if hasattr(obj, "__dict__") else obj
            for key, value in obj_to_traverse.items():
                _populate_entry(key, value, parent_key)

    if hasattr(obj, "__dict__") or isinstance(obj, dict):
        _populate_unique_dict(obj)
    else:
        _populate_entry(obj_name, obj)

    return uid_from_hashlib(str(unique_dict))
The names of this function and the previous one do not follow a similar naming scheme even though they are closely related; I would try to rename them. Additionally, both functions lack argument descriptions in their docstrings, and this one seems more convoluted than needed. I would suggest refactoring it into something with a single internal helper function, like:
def generate_orchestration_uid(
    obj: Any, obj_name: str = "", members_to_disregard: tuple[str] = ()
) -> str:
    """Generate a unique id for a runtime object.

    The unique id is generated by creating a dictionary that describes the runtime state of the object.
    For primitive types, the dictionary contains the type and the value.
    For arrays, the dictionary contains the shape and the dtype -not their content-.

    Keep in mind that this function is not supposed to be generic, and should only be used for
    DaCe orchestration purposes.

    Args:
        obj:
        obj_name:
        members_to_disregard:
    """
    primitive_dtypes = (*orchestration_dtypes.ICON4PY_PRIMITIVE_DTYPES, str, uuid.UUID, np.dtype)
    static_data = {}

    def _populate_entry(key: str, value: Any) -> None:
        if key in members_to_disregard:
            return

        if isinstance(value, primitive_dtypes):
            static_data[key] = {"type": "primitive_dtypes", "value": str(value)}
        elif isinstance(value, (np.ndarray, gtx.Field)):
            static_data[key] = {
                "type": "array/field",
                "shape": str(value.shape),
                "dtype": str(value.dtype),
            }
        elif isinstance(value, (list, tuple)):
            item_types = {type(i) for i in value}
            # A sequence holding a single primitive type (or nothing) is summarized;
            # anything else is traversed item by item.
            if len(item_types) <= 1 and all(issubclass(t, primitive_dtypes) for t in item_types):
                prim_type = next(iter(item_types), None)
                static_data[key] = {
                    "type": f"array-like[{prim_type!s}]",
                    "length": str(len(value)),
                }
            else:
                for child_key, child_value in enumerate(value):
                    _populate_entry(f"{key}.{child_key!s}", child_value)
        elif isinstance(value, decomposition.ExchangeRuntime):
            pass
        elif value is None:
            static_data[key] = {"type": "None", "value": "None"}
        elif isinstance(value, dict) or hasattr(value, "__dict__"):
            # Traverse `value`, not the top-level `obj`; otherwise the recursion never ends.
            for child_key, child_value in getattr(value, "__dict__", value).items():
                # Avoid a leading dot when the parent key is empty.
                child_full_key = f"{key}.{child_key!s}" if key else str(child_key)
                _populate_entry(child_full_key, child_value)
        else:
            raise ValueError(f"Type {type(value)} is not supported.")

    _populate_entry(obj_name, obj)
    return uid_from_hashlib(str(sorted(static_data.items(), key=operator.itemgetter(0))))
Also note that the final dict with the static information should be sorted in a consistent way to avoid false negatives if the items are created/traversed in a different order.
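A small sketch of why the ordering matters: `str()` of a dict follows insertion order, so two dicts with the same content can hash differently unless the items are sorted first. The helper `uid_from_items`, the use of SHA-256 and the example keys are assumptions for illustration; the PR's `uid_from_hashlib` may work differently.

```python
import hashlib


def uid_from_items(static_data: dict) -> str:
    """Hash a description dict independently of its insertion order."""
    ordered = sorted(static_data.items(), key=lambda item: item[0])
    return hashlib.sha256(str(ordered).encode("utf-8")).hexdigest()


def naive_uid(static_data: dict) -> str:
    """Hash the plain string form, which depends on insertion order."""
    return hashlib.sha256(str(static_data).encode("utf-8")).hexdigest()


# Hypothetical entries, same content in both dicts.
a = {
    "grid.num_cells": {"type": "primitive_dtypes", "value": "10"},
    "ndyn_substeps": {"type": "primitive_dtypes", "value": "5"},
}
b = dict(reversed(list(a.items())))  # same items, different insertion order

naive_equal = naive_uid(a) == naive_uid(b)
sorted_equal = uid_from_items(a) == uid_from_items(b)
print(naive_equal, sorted_equal)
```

With the unsorted form, a different traversal order would look like a new configuration and trigger an unnecessary recompilation.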
def wrapper(*args, **kwargs):
    self = args[0]
Why not simply
def wrapper(self, *args, **kwargs):
The orchestrator could also be used for non-member functions (not yet implemented, but possibly in the future). That is the reason for not having self in the argument list.
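To illustrate the trade-off, here is a minimal sketch of a wrapper that keeps the generic `*args` signature so that one code path can serve both methods and free functions. The names `orchestrate_any`, `seen_owners` and `free_run` are hypothetical, and the free-function branch is not something the PR implements.

```python
import functools

seen_owners = []  # records the owner detected on each call, for illustration only


def orchestrate_any(func=None, *, method=True):
    def _decorator(fuse_func):
        @functools.wraps(fuse_func)
        def wrapper(*args, **kwargs):
            # Methods: the instance arrives as args[0]. Free functions: no owner.
            owner = args[0] if method else None
            seen_owners.append(type(owner).__name__)
            return fuse_func(*args, **kwargs)

        return wrapper

    return _decorator(func) if func else _decorator


class Diffusion:
    @orchestrate_any(method=True)
    def run(self, x):
        return x + 1


@orchestrate_any(method=False)
def free_run(x):
    return x + 1


results = (Diffusion().run(1), free_run(1))
print(results, seen_owners)
```

With `def wrapper(self, *args, **kwargs)` the first positional argument of a free function would be swallowed as `self`, so that signature only fits the method-only case.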
This PR concerns only the DaCe backend.
Currently, the Diffusion granule calls the various stencils one after the other (in _do_diffusion_step). This means that every stencil has its own SDFG, so no stencil is aware of the others, which limits the analyzability of the full granule. This PR introduces a decorator that fuses all these SDFGs into one compilation unit, allowing DaCe to perform further analysis and optimizations. Placing a GT4Py program inside a DaCe program region, and extracting the underlying SDFG, is possible due to this GT4Py PR.
The halo exchanges are also handled by the DaCe orchestrator. A follow-up PR will introduce automated halo exchanges. Currently, the halo exchange class implements the SDFGConvertible interface, like a GT4Py Program.
The orchestrator is activated either through the environment variable ICON4PY_DACE_ORCHESTRATION=AnyValue or through the pytest option --dace-orchestration=AnyValue.
The orchestrator supports ahead-of-time compilation. However, given that DaCe does not support nested Structures -like self in the Diffusion class-, some of the arguments need to be provided at compile time, through dace.compiletime. This annotation means that the corresponding argument is taken from the closure of the function.
The orchestrator provides full caching support: when a dace.compiletime argument changes, the fused SDFG is recompiled.
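The activation switch and the recompile-on-change caching described above can be sketched in plain Python. This is a rough illustration under stated assumptions: only the variable name `ICON4PY_DACE_ORCHESTRATION` comes from the description, while `orchestration_enabled`, `CompileCache` and the fake compile function are hypothetical and involve no DaCe calls.

```python
import os
from typing import Any, Callable, Mapping


def orchestration_enabled(env: Mapping[str, str] = os.environ) -> bool:
    # Any value activates the orchestrator; only an unset variable disables it.
    return env.get("ICON4PY_DACE_ORCHESTRATION") is not None


class CompileCache:
    """Hypothetical cache: one compiled artifact per set of compile-time values."""

    def __init__(self, compile_fn: Callable[[dict], Any]):
        self._compile_fn = compile_fn
        self._entries: dict = {}
        self.compilations = 0

    def get(self, compile_time_args: Mapping[str, Any]) -> Any:
        # Sorted (name, repr) pairs give a key that ignores argument order.
        key = tuple(sorted((k, repr(v)) for k, v in compile_time_args.items()))
        if key not in self._entries:
            self.compilations += 1  # a changed compile-time value lands here
            self._entries[key] = self._compile_fn(dict(compile_time_args))
        return self._entries[key]


cache = CompileCache(lambda a: f"compiled({sorted(a.items())})")
first = cache.get({"limit": 1})
again = cache.get({"limit": 1})    # cache hit, no recompilation
changed = cache.get({"limit": 2})  # compile-time value changed, recompiles

print(orchestration_enabled({"ICON4PY_DACE_ORCHESTRATION": "1"}), cache.compilations)
```

The key is built only from the compile-time values, so run-time arguments such as field contents never trigger a recompilation in this sketch.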